(use-modules
 ;; SRFI 64 for unit testing facilities
 (srfi srfi-64)
 ;; SRFI 8 for `receive` form
 (srfi srfi-8)
 ;; utils - the code to be tested
 (decision-tree)
 ;; Utilities for testing
 (utils test)
 ;; Dependencies for testing the code to be tested
 (dataset)
 (metrics)
 (pruning)
 (prediction)
 (data-point)
 (tree))


;; Shared fixture: ten rows, each holding two numeric features (columns 0
;; and 1) followed by a class label (column 2).  The first five rows are
;; labelled 0, the last five are labelled 1.  Used by the "evaluate-algorithm"
;; group below.
(define TEST-DATA
  (list #(2.771244718 1.784783929 0)
        #(1.728571309 1.169761413 0)
        #(3.678319846 2.81281357 0)
        #(3.961043357 2.61995032 0)
        #(2.999208922 2.209014212 0)
        #(7.497545867 3.162953546 1)
        #(9.00220326 3.339047188 1)
        #(7.444542326 0.476683375 1)
        #(10.12493903 3.234550982 1)
        #(6.642287351 3.319983761 1)))


;; 10^-9, presumably a tolerance for approximate floating-point comparisons.
;; NOTE(review): not referenced in this file as far as visible here -- confirm
;; it is used elsewhere, or remove it.
(define PRECISION (expt 10 -9))


;; Open the top-level test suite; it is closed by the matching `test-end`
;; at the bottom of the file.
(test-begin "decision-tree-test")


(test-group
 "split-data"
 ;; Split on feature column 0 at 1.5: rows whose feature value lies below the
 ;; split value are expected in the first list, all other rows in the second.
 ;; Both lists keep the relative order of the input rows.
 (let ([rows (list #(1.0 1.0 1.0 1.0 0)
                   #(1.2 1.0 1.0 1.0 0)
                   #(1.4 1.0 1.0 1.0 0)
                   #(1.6 1.0 1.0 1.0 0)
                   #(1.8 1.0 1.0 1.0 0)
                   #(2.0 1.0 1.0 1.0 0))]
       [expected-below (list #(1.0 1.0 1.0 1.0 0)
                             #(1.2 1.0 1.0 1.0 0)
                             #(1.4 1.0 1.0 1.0 0))]
       [expected-rest (list #(1.6 1.0 1.0 1.0 0)
                            #(1.8 1.0 1.0 1.0 0)
                            #(2.0 1.0 1.0 1.0 0))])
   (test-equal "split-data-1"
     (list expected-below expected-rest)
     (split-data rows 0 1.5)))

 ;; Split on feature column 1 at 2.5.  The rows that end up on each side are
 ;; interleaved in the input, so this also checks that rows are assigned by
 ;; value rather than by position.  A row exactly equal to the split value is
 ;; not exercised by either case.
 (let ([rows (list #(1.0 1.0 1.0 1.0 0)
                   #(1.2 4.0 1.0 1.0 0)
                   #(1.4 1.0 1.0 1.0 0)
                   #(1.6 3.0 1.0 1.0 0)
                   #(1.8 1.0 1.0 1.0 0)
                   #(2.0 2.0 1.0 1.0 0))]
       [expected-below (list #(1.0 1.0 1.0 1.0 0)
                             #(1.4 1.0 1.0 1.0 0)
                             #(1.8 1.0 1.0 1.0 0)
                             #(2.0 2.0 1.0 1.0 0))]
       [expected-rest (list #(1.2 4.0 1.0 1.0 0)
                            #(1.6 3.0 1.0 1.0 0))])
   (test-equal "split-data-2"
     (list expected-below expected-rest)
     (split-data rows 1 2.5))))


(test-group
 "get-best-split"
 ;; Runs on the shared TEST-DATA fixture (features in columns 0 and 1, label
 ;; in column 2).  A copy is taken so that nothing done to the list here can
 ;; affect later groups that also use TEST-DATA.
 (let ([test-data (list-copy TEST-DATA)]
       [feature-column-indices (list 0 1)]
       [label-column-index 2])
   ;; get-best-split must choose the split that separates the two classes.
   (test-equal "get-best-split-1"
     ;; In the left branch the values of the first feature are all lower than
     ;; the values of the first feature in the right branch.
     ;;
     ;; In the right branch one value of the second feature (0.476683375) is
     ;; lower than every value of that feature in the left branch.  All other
     ;; values of the second feature in the right branch are higher than the
     ;; ones in the left branch.  That makes the second feature an imperfect
     ;; split feature.
     ;;
     ;; Hence the best split is on the first feature (column 0), at the
     ;; smallest first-feature value of the right branch (6.642287351).  Each
     ;; branch then contains a single label, and the expected split cost (the
     ;; last argument of `make-split`) is 0.0.
     (make-split 0
                 6.642287351
                 (list
                  ;; left branch data
                  (list #(2.771244718 1.784783929 0)
                        #(1.728571309 1.169761413 0)
                        #(3.678319846 2.81281357 0)
                        #(3.961043357 2.61995032 0)
                        #(2.999208922 2.209014212 0))
                  ;; right branch data
                  (list #(7.497545867 3.162953546 1)
                        #(9.00220326 3.339047188 1)
                        #(7.444542326 0.476683375 1)
                        #(10.12493903 3.234550982 1)
                        #(6.642287351 3.319983761 1)))
                 0.0)
     (get-best-split test-data
                     feature-column-indices
                     label-column-index))))


(test-group
 "fit"
 ;; Every case builds the expected tree by hand with `make-node` and
 ;; `make-leaf-node`.  It takes the split feature index and split value from
 ;; `get-best-split` on the same rows, then compares the result with the tree
 ;; that `fit` returns.  Features are in columns 0 and 1; the label is in
 ;; column 2.  This column configuration is bound once and passed to both
 ;; `get-best-split` and `fit`, so expected and actual trees cannot disagree
 ;; about it.
 (let ([feature-column-indices (list 0 1)]
       [label-column-index 2])

   ;; fit-1: the label is determined by feature 0 alone (1.0-1.4 -> label 0,
   ;; 2.0-2.4 -> label 1), so a single split yields two single-label leaves.
   (let ([test-data (list #(1.0 1.0 0)
                          #(1.2 1.0 0)
                          #(1.1 1.0 0)
                          #(1.4 1.0 0)
                          #(1.2 1.0 0)
                          #(1.2 1.0 0)
                          ;; rows below are labelled 1
                          #(2.3 1.0 1)
                          #(2.0 1.0 1)
                          #(2.3 1.0 1)
                          #(2.0 1.0 1)
                          #(2.3 1.0 1)
                          #(2.0 1.0 1)
                          #(2.4 1.0 1))])
     (test-equal "fit-1"
       (let ([best-split (get-best-split test-data
                                         feature-column-indices
                                         label-column-index)])
         (make-node test-data
                    (split-feature-index best-split)
                    (split-value best-split)
                    (make-leaf-node (list #(1.0 1.0 0)
                                          #(1.2 1.0 0)
                                          #(1.1 1.0 0)
                                          #(1.4 1.0 0)
                                          #(1.2 1.0 0)
                                          #(1.2 1.0 0)))
                    (make-leaf-node (list #(2.3 1.0 1)
                                          #(2.0 1.0 1)
                                          #(2.3 1.0 1)
                                          #(2.0 1.0 1)
                                          #(2.3 1.0 1)
                                          #(2.0 1.0 1)
                                          #(2.4 1.0 1)))))
       (fit #:train-data test-data
            #:feature-column-indices feature-column-indices
            #:label-column-index label-column-index
            #:max-depth 2
            #:min-data-points 4
            #:min-data-points-ratio 0.02)))

   ;; fit-2: two rows with a high feature-0 value are labelled 0.  They differ
   ;; from the label-1 rows only in feature 1 (1.1 instead of 1.0), so the
   ;; right branch needs one more split.
   (let* ([test-data (list #(1.0 1.0 0)
                           #(1.2 1.0 0)
                           #(1.1 1.0 0)
                           #(1.4 1.0 0)
                           #(1.2 1.0 0)
                           #(1.2 1.0 0)
                           ;; rows below share the high feature-0 range
                           #(2.3 1.1 0)
                           #(2.0 1.1 0)
                           #(2.3 1.0 1)
                           #(2.0 1.0 1)
                           #(2.3 1.0 1)
                           #(2.0 1.0 1)
                           #(2.4 1.0 1))]
          [best-split (get-best-split test-data
                                      feature-column-indices
                                      label-column-index)])
     (test-equal "fit-2"
       (make-node test-data
                  (split-feature-index best-split)
                  (split-value best-split)
                  (make-leaf-node (list #(1.0 1.0 0)
                                        #(1.2 1.0 0)
                                        #(1.1 1.0 0)
                                        #(1.4 1.0 0)
                                        #(1.2 1.0 0)
                                        #(1.2 1.0 0)))
                  (let* ([subset (list #(2.3 1.1 0)
                                       #(2.0 1.1 0)
                                       #(2.3 1.0 1)
                                       #(2.0 1.0 1)
                                       #(2.3 1.0 1)
                                       #(2.0 1.0 1)
                                       #(2.4 1.0 1))]
                         [subset-split (get-best-split subset
                                                       feature-column-indices
                                                       label-column-index)])
                    (make-node subset
                               (split-feature-index subset-split)
                               (split-value subset-split)
                               (make-leaf-node (list #(2.3 1.0 1)
                                                     #(2.0 1.0 1)
                                                     #(2.3 1.0 1)
                                                     #(2.0 1.0 1)
                                                     #(2.4 1.0 1)))
                               (make-leaf-node (list #(2.3 1.1 0)
                                                     #(2.0 1.1 0))))))
       (fit #:train-data test-data
            #:feature-column-indices feature-column-indices
            #:label-column-index label-column-index
            #:max-depth 3
            #:min-data-points 2
            #:min-data-points-ratio 0.02)))

   ;; fit-3: fitting the right-branch rows of fit-2 on their own must yield
   ;; the same one-split subtree that fit-2 expects.
   (let* ([test-data (list #(2.3 1.1 0)
                           #(2.0 1.1 0)
                           #(2.3 1.0 1)
                           #(2.0 1.0 1)
                           #(2.3 1.0 1)
                           #(2.0 1.0 1)
                           #(2.4 1.0 1))]
          [best-split (get-best-split test-data
                                      feature-column-indices
                                      label-column-index)])
     (test-equal "fit-3"
       (make-node test-data
                  (split-feature-index best-split)
                  (split-value best-split)
                  (make-leaf-node (list #(2.3 1.0 1)
                                        #(2.0 1.0 1)
                                        #(2.3 1.0 1)
                                        #(2.0 1.0 1)
                                        #(2.4 1.0 1)))
                  (make-leaf-node (list #(2.3 1.1 0)
                                        #(2.0 1.1 0))))
       (fit #:train-data test-data
            #:feature-column-indices feature-column-indices
            #:label-column-index label-column-index
            #:max-depth 3
            #:min-data-points 2
            #:min-data-points-ratio 0.02)))))


(test-group
 "column-uniform?"
 ;; `column-uniform?` receives a column (a list of values) and an equality
 ;; predicate.  Per the test names below, it is expected to return true when
 ;; every value in the column is equal under that predicate, including the
 ;; empty case (`empty-dataset`, presumably the empty list -- verify).
 (let ([constant-label-rows (list #(1.0 1.0 0)
                                  #(1.2 1.0 0)
                                  #(1.1 1.0 0)
                                  #(1.4 1.0 0)
                                  #(1.2 1.0 0)
                                  #(1.2 1.0 0))])
   (test-assert "column-uniform? of empty column should be true"
     (column-uniform? empty-dataset =))

   (test-assert "column-uniform? of uniform column should result in true -- 1"
     (column-uniform? (list 1 1 1) =))

   ;; Column 2 of these rows is 0 in every row.
   (test-assert "column-uniform? of uniform column should result in true -- 2"
     (column-uniform? (dataset-get-col constant-label-rows 2) =))

   (test-assert "column-uniform? of non-uniform column should result in false"
     (not (column-uniform? (list 1 2 3) =)))))


(test-group
 "dataset-partition"
 ;; `dataset-partition` takes a predicate and a dataset and returns two
 ;; values: the rows that satisfy the predicate and the rows that do not,
 ;; each in input order.  The predicate here selects the rows whose column 2
 ;; holds 0.
 (let ([label-is-zero? (lambda (row)
                         (zero? (data-point-get-col row 2)))]
       [rows (list #(2.3 1.1 1)
                   #(2.0 1.1 1)
                   #(2.3 1.0 0)
                   #(2.0 1.0 0)
                   #(2.3 1.0 0)
                   #(2.0 1.0 0)
                   #(2.4 1.0 0))])
   (test-equal "dataset-partition should split at given value of specified column"
     (list (list #(2.3 1.0 0)
                 #(2.0 1.0 0)
                 #(2.3 1.0 0)
                 #(2.0 1.0 0)
                 #(2.4 1.0 0))
           (list #(2.3 1.1 1)
                 #(2.0 1.1 1)))
     (receive (selected rejected)
         (dataset-partition label-is-zero? rows)
       (list selected rejected)))))


(test-group
 "cross-validation-split"
 ;; The items 0..19 are distributed over 4 folds of 5 items each.  The test
 ;; pins the exact folds produced for random seed 12345.
 (let ([items (iota 20)]
       [expected-folds '((6 19 13 0 10)
                         (2 16 3 17 4)
                         (11 8 7 14 1)
                         (5 12 18 9 15))])
   (test-equal
       expected-folds
     (cross-validation-split items 4 #:random-seed 12345))))


(test-group
 "leave-one-out-k-folds"
 ;; Four folds, each made of four identical rows.  Leaving out the #(3 3)
 ;; fold must return the remaining folds in their original order.  Each call
 ;; to `fold-of` builds a fresh list, so the fold that is left out is a
 ;; separate (but equal) list from the one in the input.
 (let ([fold-of (lambda (row) (make-list 4 row))])
   (test-equal
       (map fold-of (list #(1 1) #(2 2) #(4 4)))
     (leave-one-out-k-folds (map fold-of (list #(1 1) #(2 2) #(3 3) #(4 4)))
                            (fold-of #(3 3))))))


(test-group
 "select-min-cost-split"

 ;; The fourth argument of `make-split` appears to be the split's cost.  Of
 ;; the three candidates, the one with cost 0.0 is expected to be selected.
 (let ([costly (make-split 0 1.1 '() 2.0)]
       [cheaper (make-split 1 2.67 '() 1.0)]
       [cheapest (make-split 2 9.78 '() 0.0)])
   (test-equal "select-min-cost-split selects best of 3 splits"
     (make-split 2 9.78 '() 0.0)
     (select-min-cost-split costly cheaper cheapest))))


(test-group
 "evaluate-algorithm "
 ;; Smoke test: running the cross-validated evaluation on TEST-DATA with
 ;; N folds must yield one result per fold.  The fold count is bound once,
 ;; so the expected length cannot drift from the `#:n-folds` argument.
 ;; TODO: real test code -- also check the per-fold results themselves.
 (let ([n-folds 4])
   (test-equal
       n-folds
     (length
      (evaluate-algorithm
       #:dataset TEST-DATA
       #:n-folds n-folds
       #:feature-column-indices (list 0 1)
       #:label-column-index 2
       #:max-depth 3
       #:min-data-points 4
       #:min-data-points-ratio 0.02
       #:min-impurity-split (expt 10 -7)
       #:stop-at-no-impurity-improvement #t
       #:random-seed 0)))))


;; Close the top-level suite opened by `test-begin` near the top of the file.
(test-end "decision-tree-test")
